"""Providing Inputs to tfe models.  For internal use in Sequential only."""
import numpy as np
from tensorflow.keras import backend
from tensorflow.python.framework import tensor_shape

import tf_encrypted as tfe
from tf_encrypted.keras.engine.base_layer import Layer


class InputLayer(Layer):
    """Layer to be used as an entry point into a Network (a graph of layers).

    The layer creates its own input tensor: a TF Encrypted private variable
    initialised with zeros of shape `(batch_size,) + input_shape`, exposed as
    the `input` attribute. Wrapping an existing tensor is not supported.

    It is generally recommended to use the functional layer API via `Input`
    (which creates an `InputLayer`) rather than using `InputLayer` directly.

    Arguments:
        input_shape: Shape tuple (not including the batch axis), a single int
            (interpreted as a one-dimensional shape), or a `TensorShape`
            instance (not including the batch axis).
        batch_size: Static input batch size (integer). Required, unless it is
            supplied as the first entry of `batch_input_shape`.
        dtype: Not supported; must be left as `None`.
        input_tensor: Not supported; must be left as `None`.
        sparse: Not supported; must be left falsy.
        name: Name of the layer (string). Autogenerated when not provided.
        **kwargs: Only `batch_input_shape` (full shape including the batch
            axis) is accepted, as an alternative to `input_shape` plus
            `batch_size`.

    Raises:
        ValueError: If both `input_shape` and `batch_input_shape` are given,
            if an unknown keyword argument is passed, or if no input shape
            can be determined.
        NotImplementedError: If the batch size is `None`, or if `dtype`,
            `input_tensor` or `sparse` is set.
    """

    def __init__(
        self,
        input_shape=None,
        batch_size=None,
        dtype=None,
        input_tensor=None,
        sparse=False,
        name=None,
        **kwargs,
    ):
        # An explicit `batch_input_shape=None` is treated like an omitted
        # argument instead of failing on `None[0]`.
        batch_input_shape = kwargs.pop("batch_input_shape", None)
        if batch_input_shape is not None:
            if input_shape and batch_input_shape:
                raise ValueError(
                    "Only provide the input_shape OR "
                    "batch_input_shape argument to "
                    "InputLayer, not both at the same time."
                )
            batch_size = batch_input_shape[0]
            input_shape = batch_input_shape[1:]
        if kwargs:
            raise ValueError(
                "Unrecognized keyword arguments: {}".format(list(kwargs))
            )

        # NOTE(review): `name` is resolved here but is not forwarded to the
        # base class below, so it currently has no visible effect on the
        # layer. Confirm whether `Layer.__init__` accepts `name` and pass it.
        if not name:
            prefix = "input"
            name = prefix + "_" + str(backend.get_uid(prefix))

        # Features of the reference Keras layer that are not available here.
        if batch_size is None:
            raise NotImplementedError()
        if input_tensor is not None:
            raise NotImplementedError()
        if dtype is not None:
            raise NotImplementedError()
        if sparse:
            raise NotImplementedError()

        super(InputLayer, self).__init__()
        # There is nothing to build for an input layer.
        self.built = True
        self.batch_size = batch_size

        # Normalise the accepted shape spellings to a plain tuple.
        if isinstance(input_shape, tensor_shape.TensorShape):
            input_shape = tuple(input_shape.as_list())
        elif isinstance(input_shape, int):
            input_shape = (input_shape,)

        if input_shape is None:
            raise ValueError("Input shape must be defined for the first layer.")
        self._batch_input_shape = (batch_size,) + tuple(input_shape)

        # Create a zero-initialised private variable with the full batch
        # shape to serve as the tensor the layer is called on.
        self.input = tfe.define_private_variable(np.zeros(self._batch_input_shape))


def Input(  # pylint: disable=invalid-name
    shape=None,
    batch_size=None,
    name=None,
    dtype=None,
    sparse=False,
    tensor=None,
    **kwargs,
):
    """`Input()` is used to instantiate the input tensor of a model.

    It builds an `InputLayer` from the given arguments and returns that
    layer's `input` attribute.

    Arguments:
        shape: A shape tuple (integers), not including the batch size.
            For instance, `shape=(32,)` indicates that the expected input
            will be batches of 32-dimensional vectors.
        batch_size: Static batch size (integer). `InputLayer` requires it,
            unless it is given as the first entry of `batch_shape`.
        name: An optional name string for the layer.
            Should be unique in a model (do not reuse the same name twice).
            It will be autogenerated if it isn't provided.
        dtype: Not supported; must be left as `None`.
        sparse: Not supported; must be left falsy.
        tensor: Not supported; must be left as `None`.
        **kwargs: Only `batch_shape` (full shape including the batch axis)
            is accepted, as an alternative to `shape` plus `batch_size`.

    Returns:
      The `input` tensor of the newly created `InputLayer`.

    Example:
    ```python
    x = Input(shape=(32,), batch_size=16)
    ```

    Raises:
      ValueError: if both `shape` and `batch_shape` are given, if an unknown
        keyword argument is passed, or if neither a shape nor a tensor is
        provided.
      NotImplementedError: if `sparse`, `dtype` or `tensor` is set.
    """
    # An explicit `batch_shape=None` is treated like an omitted argument
    # instead of failing on `None[0]`.
    batch_shape = kwargs.pop("batch_shape", None)
    if batch_shape is not None:
        if shape and batch_shape:
            raise ValueError(
                "Only provide the shape OR "
                "batch_shape argument to "
                "Input, not both at the same time."
            )
        batch_size = batch_shape[0]
        shape = batch_shape[1:]
    if kwargs:
        raise ValueError(
            "Unrecognized keyword arguments: {}".format(list(kwargs))
        )

    if shape is None and tensor is None:
        raise ValueError(
            "Please provide to Input either a `shape`"
            " or a `tensor` argument. Note that "
            "`shape` does not include the batch "
            "dimension."
        )

    # Features of the reference Keras API that are not available here.
    if sparse:
        raise NotImplementedError()
    if dtype is not None:
        raise NotImplementedError()
    if tensor is not None:
        raise NotImplementedError()

    # `batch_shape` has already been split into `batch_size` and `shape`
    # above, which is the same split `InputLayer` applies to
    # `batch_input_shape`, so a single call covers both spellings.
    input_layer = InputLayer(
        input_shape=shape,
        batch_size=batch_size,
        name=name,
        dtype=dtype,
        sparse=sparse,
        input_tensor=tensor,
    )

    return input_layer.input
